{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ipSgY-L3z8kG"
      },
      "source": [
        "# Fine Tune Stable Diffusion\n",
        "\n",
        "Fine-tuning Stable Diffusion on a Pokemon dataset.\n",
        "For more details, see the [Lambda Labs examples repo](https://github.com/LambdaLabsML/examples).\n",
        "\n",
        "We recommend using a multi-GPU machine, for example an instance from [Lambda GPU Cloud](https://lambdalabs.com/service/gpu-cloud). If running on Colab, this notebook is likely to need a GPU with >16GB of VRAM and a high-RAM runtime, which will almost certainly require Colab Pro or Pro+. (If you get errors such as `Killed` or `CUDA out of memory`, then one of these is not sufficient.)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bjNGOU6Pz8kH"
      },
      "outputs": [],
      "source": [
        "!git clone https://github.com/justinpinkney/stable-diffusion.git\n",
        "%cd stable-diffusion\n",
        "!pip install --upgrade pip\n",
        "!pip install -r requirements.txt\n",
        "\n",
        "!pip install --upgrade keras # on Lambda Stack we need to upgrade keras\n",
        "!pip uninstall -y torchtext # on Colab we need to remove torchtext"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m1AkWL270DSE"
      },
      "outputs": [],
      "source": [
        "!nvidia-smi"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FTf_OdfEz8kI"
      },
      "outputs": [],
      "source": [
        "# Check the dataset\n",
        "from datasets import load_dataset\n",
        "ds = load_dataset(\"lambdalabs/pokemon-blip-captions\", split=\"train\")\n",
        "sample = ds[0]\n",
        "display(sample[\"image\"].resize((256, 256)))\n",
        "print(sample[\"text\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5V1uUVjQz8kI"
      },
      "source": [
        "To get the weights, you'll need to [go to the model card](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original), read the license, and tick the checkbox if you agree."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PbMtzkytz8kI"
      },
      "outputs": [],
      "source": [
        "!pip install huggingface_hub\n",
        "from huggingface_hub import notebook_login\n",
        "\n",
        "notebook_login()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vnnbNNycz8kI"
      },
      "outputs": [],
      "source": [
        "from huggingface_hub import hf_hub_download\n",
        "ckpt_path = hf_hub_download(repo_id=\"CompVis/stable-diffusion-v-1-4-original\", filename=\"sd-v1-4-full-ema.ckpt\", use_auth_token=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y8VDyuQxz8kJ"
      },
      "source": [
        "Set the parameters below according to your GPU setup. The settings shown were used for training on a 2xA6000 machine (the A6000 has 48GB of VRAM). On this setup, good results are achieved in around 6 hours.\n",
        "\n",
        "You can make up for using smaller batches or fewer GPUs by accumulating batches:\n",
        "\n",
        "`total batch size = batch size * n gpus * accumulate batches`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WVssEQJfz8kJ"
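      },
      "outputs": [],
      "source": [
        "# Settings used on a 2xA6000 machine:\n",
        "BATCH_SIZE = 4  # batch size per GPU\n",
        "N_GPUS = 2  # number of GPUs to train on\n",
        "ACCUMULATE_BATCHES = 1  # number of batches to accumulate gradients over\n",
        "\n",
        "# Device indices as a comma-separated string with a trailing comma, e.g. \"0,1,\"\n",
        "gpu_list = \",\".join(str(x) for x in range(N_GPUS)) + \",\"\n",
        "print(f\"Using GPUs: {gpu_list}\")"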
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w7sLO53fz8kJ"
      },
      "outputs": [],
      "source": [
        "# Run training\n",
        "!(python main.py \\\n",
        "    -t \\\n",
        "    --base configs/stable-diffusion/pokemon.yaml \\\n",
        "    --gpus \"$gpu_list\" \\\n",
        "    --scale_lr False \\\n",
        "    --num_nodes 1 \\\n",
        "    --check_val_every_n_epoch 10 \\\n",
        "    --finetune_from \"$ckpt_path\" \\\n",
        "    data.params.batch_size=\"$BATCH_SIZE\" \\\n",
        "    lightning.trainer.accumulate_grad_batches=\"$ACCUMULATE_BATCHES\" \\\n",
        "    data.params.validation.params.n_gpus=\"$N_GPUS\" \\\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sk9wHFKgz8kJ"
      },
      "outputs": [],
      "source": [
        "# Run the model\n",
        "!(python scripts/txt2img.py \\\n",
        "    --prompt 'robotic cat with wings' \\\n",
        "    --outdir 'outputs/generated_pokemon' \\\n",
        "    --H 512 --W 512 \\\n",
        "    --n_samples 4 \\\n",
        "    --config 'configs/stable-diffusion/pokemon.yaml' \\\n",
        "    --ckpt 'path/to/your/checkpoint')"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "background_execution": "on",
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3.8.10 64-bit",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.8.10"
    },
    "orig_nbformat": 4,
    "vscode": {
      "interpreter": {
        "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
